"""
PINF for the steady-state Fokker-Planck (SFP) equation

    a * div(x p(x)) + s**2/2 * Laplacian(p)(x) = 0,    x in R^d

The exact (normalized) solution is
    p(x) = (a/(pi*s**2))**(d/2) * exp(-a*|x|**2/s**2)
"""

import os
import time
import copy
import datetime
import torch
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.ticker as tk

#from torchdiffeq import odeint_adjoint as odeint
from torchdiffeq import odeint

# Fixed seed so weight initialisation and sampling are reproducible on the
# CPU generator and on every CUDA device.
_SEED = 111
torch.manual_seed(_SEED)
torch.cuda.manual_seed_all(_SEED)

def trace_df_dz(f, z):
    """Return the per-row trace of the Jacobian df/dz.

    One autograd pass is made per output column i, and the i-th column of
    d(sum_rows f[:, i])/dz is kept. ``create_graph=True`` keeps the result
    differentiable, so it can appear inside a training loss.

    NOTE(review): summing f[:, i] over rows yields per-row derivatives only
    when rows do not interact (independent batch samples); assumes f and z
    are 2-D with f's columns indexing z's columns.
    """
    diagonal_terms = (
        torch.autograd.grad(f[:, col].sum(), z, create_graph=True)[0].contiguous()[:, col].contiguous()
        for col in range(f.shape[1])
    )
    return sum(diagonal_terms, 0.).contiguous()


class MLP(torch.nn.Module):
    """Feed-forward network: tanh after every Linear except the last.

    Every Linear weight is set to zero at construction. Biases keep PyTorch's
    default random initialisation, so the initial output is the (random) bias
    of the last layer, independent of the input.

    Args:
        layers: layer widths, e.g. [in, hidden, ..., out].
    """

    def __init__(self, layers):
        super().__init__()
        self.net = self.create_net(layers)

    def create_net(self, layers):
        """Build one zero-weight Linear per consecutive (n_in, n_out) pair."""
        linears = torch.nn.ModuleList()
        for n_in, n_out in zip(layers[:-1], layers[1:]):
            linear = torch.nn.Linear(n_in, n_out)
            torch.nn.init.constant_(linear.weight.data, 0.0)
            linears.append(linear)
        return linears

    def forward(self, x):
        hidden_layers = self.net[:-1]
        for linear in hidden_layers:
            x = torch.tanh(linear(x))
        head = self.net[-1]
        return head(x)
        

class CouplingLayer(torch.nn.Module):
    """Affine coupling layer.

    Coordinates where ``mask == 1`` pass through unchanged. They are fed to
    ``s_net`` and ``t_net``, whose outputs scale (by exp(s)) and shift (by t)
    the remaining coordinates.

    Args:
        mask: 0/1 tensor broadcastable against the inputs.
        layers: widths for both the scale and the shift MLPs.
    """

    def __init__(self, mask, layers):
        super().__init__()
        self.mask = mask
        self.s_net = MLP(layers)
        self.t_net = MLP(layers)

    def forward(self, x, inverse=False):
        """Apply the coupling map, or its inverse when ``inverse=True``.

        Returns:
            (u, ldj): the transformed tensor and the log-determinant of the
            Jacobian of the applied map, summed over the last dimension
            (negated for the inverse direction).
        """
        kept = x * self.mask
        free = 1 - self.mask
        log_scale = self.s_net(kept)
        shift = self.t_net(kept)

        if inverse:
            u = kept + free * (x - shift) * torch.exp(-log_scale)
            return u, -(free * log_scale).sum(dim=-1)

        u = kept + free * (x * torch.exp(log_scale) + shift)
        return u, (free * log_scale).sum(dim=-1)
    

class RealNVP(torch.nn.Module):
    """Stack of ``num_flows`` affine coupling layers with alternating masks.

    The first layer's mask is ``arange(dim) % 2``: odd-indexed coordinates
    are kept and even-indexed ones are transformed. Each following layer uses
    the complementary mask.
    """

    def __init__(self, num_flows, layers, dim, device):
        super().__init__()
        self.num_flows = num_flows
        self.layers = layers
        self.dim = dim
        self.device = device
        self.flows = self.make_flows()

    def make_flows(self):
        """Create the coupling layers on ``self.device``, flipping the mask each time."""
        mask = (torch.arange(self.dim) % 2).type(torch.float32).to(self.device)
        flows = torch.nn.ModuleList()
        for _ in range(self.num_flows):
            coupling = CouplingLayer(mask=mask, layers=self.layers).to(self.device)
            flows.append(coupling)
            mask = 1 - mask
        return flows

    def forward(self, x, inverse=False):
        """Run the layers in order (or reversed, each inverted, if ``inverse``).

        Returns:
            (x_out, total_ldj): the transformed tensor and the sum of the
            per-layer log-determinants.
        """
        ordered = reversed(self.flows) if inverse else self.flows
        total_ldj = 0
        for coupling in ordered:
            x, layer_ldj = coupling(x, inverse)
            total_ldj = total_ldj + layer_ldj
        return x, total_ldj

    
class logp_net(torch.nn.Module):
    """Log-density model: a RealNVP flow on top of a standard normal base.

    ``inverse(z)`` pushes base samples z forward to x and returns log p(x).
    ``forward(x)`` evaluates log p(x) by mapping x back to the base space.
    """

    def __init__(self, num_flows, layers, dim, device):
        super().__init__()
        self.dim = dim
        self.device = device
        self.net = RealNVP(num_flows, layers, dim, device)
        base_mean = torch.zeros(self.dim, device=self.device)
        base_cov = torch.eye(self.dim, device=self.device)
        self.pG = torch.distributions.MultivariateNormal(
            loc=base_mean, covariance_matrix=base_cov)

    def inverse(self, z):
        """Map base samples z to x; return (x, log p(x)) by change of variables."""
        x, log_det = self.net(z, inverse=False)
        return x, self.pG.log_prob(z) - log_det

    def forward(self, x):
        """Return log p(x), computed through the inverse flow."""
        z0, log_det = self.net(x, inverse=True)
        return self.pG.log_prob(z0) + log_det
    
    
class ODEs(torch.nn.Module):
    """Probability-flow ODE for the SFP equation, with log-density transport.

    ``phi`` is a flow model of log p. The augmented state ``(x, logp_x)``
    evolves by

        dx/dt      = -a*x - (s**2/2) * grad_x phi(x)
        dlogp_x/dt = -trace(d(dx/dt)/dx)

    Args:
        num_flows, layers, dim, device: forwarded to ``logp_net``.
        a, s: SFP drift / noise coefficients. Any coefficient left as None
            (the default) is read from the module-level global of the same
            name at call time, which is how the ``__main__`` script sets them.
    """

    def __init__(self, num_flows, layers, dim, device, a=None, s=None):
        super().__init__()
        self.dim = dim
        self.device = device
        self.phi = logp_net(num_flows, layers, dim, device)
        # Explicit coefficients; None means "fall back to the module globals".
        self.a = a
        self.s = s

    def _coefficients(self):
        """Return the (a, s) pair used by the vector field."""
        # getattr keeps instances pickled before these attributes existed loadable.
        own_a = getattr(self, 'a', None)
        own_s = getattr(self, 's', None)
        drift = a if own_a is None else own_a
        noise = s if own_s is None else own_s
        return drift, noise

    def predict(self, x0, logp_x0):
        """Integrate the augmented ODE from t=0 to t=1.

        Args:
            x0: starting points.
            logp_x0: log-density of the starting points.

        Returns:
            (logp_ode, logp_model): the log-density carried to the end point
            x1 by the ODE, and ``phi(x1)`` evaluated directly.
        """
        ts = torch.tensor([0., 1.]).type(torch.float32).to(self.device)

        # NOTE(review): 'rk4' is a fixed-grid solver in torchdiffeq. With no
        # step_size option it steps on `ts` itself (a single RK4 step from 0
        # to 1), and rtol/atol do not control its step size.
        x_t, logp_diff_t = odeint(
            self,
            (x0, logp_x0),
            ts,
            atol=1e-8,
            rtol=1e-8,
            method='rk4',
        )

        x1, logp_ode = x_t[-1], logp_diff_t[-1]
        logp_model = self.phi(x1)
        return logp_ode, logp_model

    def forward(self, t, states):
        """Vector field of the augmented ODE.

        Args:
            t: current time (unused; the field does not depend on t).
            states: tuple ``(x, logp_x)``.

        Returns:
            ``(dx_dt, dlogp_x_dt)``, with ``dlogp_x_dt`` shaped like ``logp_x``.
        """
        x, logp_x = states[0], states[1]
        drift, noise = self._coefficients()

        # Gradients w.r.t. x are needed even if the caller disabled autograd.
        with torch.set_grad_enabled(True):
            x.requires_grad_(True)

            logp = self.phi(x)
            dlogp_dx = torch.autograd.grad(logp.sum(), x, create_graph=True)[0]

            # Probability-flow velocity: SFP drift minus (s^2/2) * score.
            dx_dt = - drift * x - (noise * noise / 2) * dlogp_dx
            # d log p / dt = -div(dx/dt); shaped like the logp state it updates.
            dlogp_x_dt = - trace_df_dz(dx_dt, x).view(logp_x.shape)

        return (dx_dt, dlogp_x_dt)


def create_Xd(n, d, dim, device):
    """Build a tensor-product grid on [-3, 3]^d, zero-padded to ``dim`` columns.

    Returns an (n**d, dim) tensor. Column 0 varies fastest down the rows,
    column d-1 slowest, and columns d..dim-1 are all zero.
    """
    axis = torch.linspace(-3.0, 3.0, n, device=device)
    mesh = torch.meshgrid(*([axis] * d), indexing='ij')
    # meshgrid('ij') flattens with the LAST axis fastest; reverse the order so
    # the first output column is the fastest-varying one.
    coords = torch.stack([m.reshape(-1) for m in reversed(mesh)], dim=1)
    padding = torch.zeros(size=(n**d, dim-d), device=device)
    return torch.cat([coords, padding], dim=1)


class PINF(torch.nn.Module):
    """Training and evaluation driver for the flow-based SFP solver.

    Builds an ``ODEs`` model and trains it with Adam so that the log-density
    carried along the ODE (``logp_ode``) matches the flow's own log-density
    at the ODE end point (``logp_net``). It also periodically compares the
    flow density ``exp(phi(x))`` with the closed-form ``p_true`` on a fixed
    grid, records MAE/MAPE/MSE, and saves figures under ``config['path']``.

    Args:
        config: dict with keys 'num_flows', 'layers', 'dim', 'N_in', 'lr',
            'epoch', 'train_save_freq', 'test_freq', 'log_freq', 'device',
            'path'; optional 'plot_freq' (iterations between figures,
            default 50).
    """

    # Points per axis of the evaluation grid; the plots reshape by this value.
    _n_grid = 50

    def __init__(self, config):
        super().__init__()
        self.num_flows = config['num_flows']
        self.layers = config['layers']
        self.dim = config['dim']
        self.N_in = config['N_in']
        self.lr = config['lr']
        self.epoch = config['epoch']
        self.train_save_freq = config['train_save_freq']
        self.test_freq = config['test_freq']
        self.log_freq = config['log_freq']
        self.plot_freq = config.get('plot_freq', 50)
        self.device = config['device']
        self.path = config['path']

        # Evaluation grid and exact density on it; built lazily by test().
        self.x = None
        self.p_test = None

        self.model = ODEs(self.num_flows, self.layers, self.dim, self.device)
        self.opt_Adam = torch.optim.Adam(params=self.model.parameters(), lr=self.lr)

        # Iterations at which Loss / test metrics are recorded (plot x-axes).
        self.Epoch = list(range(0, self.epoch + 1, self.train_save_freq))
        self.Epoch_test = list(range(0, self.epoch + 1, self.test_freq))

        self.loss = torch.nn.MSELoss()
        self.test_p_net = {'mae': [], 'mape': [], 'mse': []}
        self.Loss = []

    def p_true(self, x):
        """Closed-form SFP density, evaluated on each row of a 2-D ``x``.

        Reads the coefficients ``a`` and ``s`` from module-level globals
        (assigned in the ``__main__`` section of this script).
        """
        xn = torch.norm(x, dim=1)
        p = (a/(torch.pi*s*s))**(self.dim/2) * torch.exp(-a*xn**2/(s*s))
        return p

    def p_pred(self, x):
        """Density predicted by the flow: exp(phi(x))."""
        return torch.exp(self.model.phi(x))

    def train_PINF(self):
        """Main loop: test, train one step, then record/log/plot at their frequencies."""
        print("Start training!")
        start = time.time()
        for it in range(self.epoch + 1):
            # Test (at it == 0 this also builds the grid that plot_fig uses)
            if it % self.test_freq == 0:
                self.test()

            # Train
            train_start = time.time()
            train_loss = self.train_model()
            if it % self.train_save_freq == 0:
                self.Loss.append(train_loss)
            train_iteration_time = time.time() - train_start

            # Print; the time shown extrapolates this iteration over log_freq iterations.
            if it % self.log_freq == 0:
                print('It: %d, Time: %.2f, Loss: %.3e' % (it, train_iteration_time * self.log_freq, train_loss))

            # Plot
            if it % self.plot_freq == 0:
                self.plot_fig(it)

        elapsed = time.time() - start
        print('Training complete! Total time: %.2f h' % (elapsed/3600))

    def train_model(self):
        """Run one Adam step and return the scalar loss.

        Draws N_in base samples, maps them through the flow to get (x0,
        log p(x0)), transports them with the ODE, and minimises the MSE
        between the ODE-transported and the directly evaluated log-density.
        """
        self.opt_Adam.zero_grad()

        # sampling
        z = self.model.phi.pG.sample([self.N_in]).to(self.device)
        x0, logp_x0 = self.model.phi.inverse(z)

        logp_ode, logp_net = self.model.predict(x0, logp_x0)
        loss = self.loss(logp_ode, logp_net)

        loss.backward()
        self.opt_Adam.step()

        return loss.item()

    def test(self):
        """Compare p_pred with p_true on the evaluation grid; record and print metrics."""
        if self.p_test is None:
            if self.dim == 1:
                # (N, 1) column: p_true takes the norm over dim=1 and the flow
                # acts on rows, so a plain 1-D linspace would crash both.
                self.x = torch.linspace(-5, 5, self._n_grid, device=self.device).unsqueeze(1)
            else:
                # Grid over (x_1, x_2) with all other coordinates fixed at zero.
                self.x = create_Xd(self._n_grid, 2, self.dim, self.device)
            self.p_test = self.p_true(self.x)

        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            p_pred = self.p_pred(self.x)
        mae = torch.abs(p_pred - self.p_test).mean().item()
        mape = torch.abs((p_pred - self.p_test)/self.p_test).mean().item()
        mse = self.loss(p_pred, self.p_test).item()

        self.test_p_net['mae'].append(mae)
        self.test_p_net['mape'].append(mape)
        self.test_p_net['mse'].append(mse)
        print('Predict by p_net:  MAE: %.3e, MAPE: %.3e, MSE: %.3e' % (mae, mape, mse))

    def plot_fig(self, epoch):
        """Save ground truth, p_net and pointwise relative error on the test grid.

        Requires ``test()`` to have run at least once (it builds ``self.x``).
        """
        with torch.no_grad():
            p_pred = self.p_pred(self.x)
        p_true = self.p_test
        mape = torch.abs(p_pred - p_true)/p_true

        X = self.x.cpu().detach().numpy()
        p_net_data = p_pred.cpu().detach().numpy()
        p_true_data = p_true.cpu().detach().numpy()
        mape_data = mape.cpu().detach().numpy()

        if self.dim == 1:
            # 1-D: line plots over the single coordinate.
            fig = plt.figure(figsize=(12, 4), dpi=200)
            ax1 = fig.add_subplot(1, 2, 1)
            ax1.plot(X[:, 0], p_true_data, 'k-', label='Ground truth')
            ax1.plot(X[:, 0], p_net_data, 'r--', label='$p_{net}$')
            ax1.set_xlabel('$x$')
            ax1.legend()

            ax2 = fig.add_subplot(1, 2, 2)
            ax2.set_title('mape')
            ax2.plot(X[:, 0], mape_data)
            ax2.set_xlabel('$x$')
        else:
            # Multi-D: surfaces/contours over the (x_1, x_2) grid from create_Xd.
            n = self._n_grid
            grid_1 = X[:, 0].reshape(n, n)
            grid_2 = X[:, 1].reshape(n, n)

            fig = plt.figure(figsize=(16, 4), dpi=200)
            ax1 = fig.add_subplot(1, 3, 1, projection='3d')
            ax1.plot_surface(grid_1, grid_2, p_true_data.reshape(n, n), cmap='rainbow')
            ax1.set_title('Ground truth')

            ax2 = fig.add_subplot(1, 3, 2, projection='3d')
            ax2.plot_surface(grid_1, grid_2, p_net_data.reshape(n, n), cmap='rainbow')
            ax2.set_title('$p_{net}$')

            ax3 = fig.add_subplot(1, 3, 3)
            ax3.set_title('mape')
            cntr3 = ax3.contourf(grid_1, grid_2, mape_data.reshape(n, n), levels=np.linspace(0, np.max(mape_data), 101), cmap='viridis')

            fig.colorbar(cntr3, ax=ax3, format="%.4f")
            ax3.set_xlabel('$x_1$')
            ax3.set_ylabel('$x_2$')
            ax3.set_aspect('equal')
            ax3.set_xticks(np.linspace(-3, 3, 7))
            ax3.set_yticks(np.linspace(-3, 3, 7))

        plt.savefig(self.path + "/pred_true_" + str(epoch) + ".png", bbox_inches='tight')
        plt.close()


if __name__ == "__main__":

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # SFP coefficients. ODEs and PINF.p_true read these names as module
    # globals, so they must stay at module scope (not inside a function).
    a = 1.
    s = 1.

    # Hyperparameters
    dim = 50
    num_flows = 4
    layers = [dim, 16, 16, dim]
    N_in = 2000
    lr = 1e-2
    epoch = 500
    train_save_freq = 10
    test_freq = 50
    log_freq = 10

    # All outputs of this run go into a directory named after the start time.
    path = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    if not os.path.exists(path):
        os.makedirs(path)

    config = dict(
        dim=dim,
        layers=layers,
        num_flows=num_flows,
        N_in=N_in,
        lr=lr,
        epoch=epoch,
        train_save_freq=train_save_freq,
        test_freq=test_freq,
        log_freq=log_freq,
        device=device,
        path=path,
    )

    model = PINF(config).to(device)
    model.train_PINF()

    # Persist the whole trained module (pickled by torch.save).
    netname = 'PINF.pth'
    torch.save(model, '/'.join((path, netname)))

    # Training-loss history on a log scale.
    plt.figure(figsize=(8, 6))
    plt.title('Training loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.yscale('log')
    plt.plot(model.Epoch, model.Loss)
    plt.savefig(path + '/Training loss.png')
    plt.close()

    # Test-error history (MAE and MAPE of p_net) on a log scale.
    plt.figure(figsize=(8, 6))
    plt.title('Testing error')
    plt.xlabel('Epoch')
    plt.ylabel('Error')
    plt.yscale('log')
    for metric, color, label in (('mae', 'r', 'MAE'), ('mape', 'g', 'MAPE')):
        plt.plot(model.Epoch_test, model.test_p_net[metric], color, label=label)
    plt.legend()
    plt.savefig(path + '/Testing error.png')
    plt.close()
